
pacman::p_load(tidyverse,
               fst,
               tidymodels,
               doParallel,
               tictoc,
               logr
)

# ---- Load tuning inputs ----
# SPECIAL_SUFFIX, TRAINING_SUFFIX, RAND_NO_CID_SMALLEST_LARGEST, CONSUMER_GROUP,
# TRAINING_START_QTR and TRAINING_END_QTR are not defined in this file; they
# must already exist in the calling environment.

# Predictor names taken from the `Variable` column of the LASSO output file.
v_lasso_vars <- read_csv(paste0("../../data/pipeline_outputs/",
                                SPECIAL_SUFFIX, "/lasso_selected_vars_",
                                TRAINING_SUFFIX, "_",
                                RAND_NO_CID_SMALLEST_LARGEST, "/",
                                "df_lasso_vars_and_dict.csv")) %>%
  pull(Variable)

# Zero-padded ids "0000" ... "0009" used to subset rows by `rand_no_cid`.
# `pad` is given as the string "0" (str_pad() documents a character pad).
# NOTE(review): assumes `rand_no_cid` is stored as a 4-character string.
v_xgb_hyperparameter_rand_no_cid <- str_pad(
  as.character(0:9),
  width = 4,
  side = "left",
  pad = "0"
)

# Read only the columns needed downstream. `columns` is a plain character
# vector, so the LASSO names are concatenated directly: all_of() is a
# tidyselect helper and is deprecated outside of a selecting function.
df_training_raw <- read_fst(
  paste0("../../data/pipeline_outputs/",
         SPECIAL_SUFFIX, "/", "train_",
         TRAINING_SUFFIX, "_", RAND_NO_CID_SMALLEST_LARGEST, "/",
         "training.fst"),
  columns = c("cid", "qtr", "t_default", "consumer_category", "rand_no_cid",
              v_lasso_vars)
) %>%
  # Keep the requested consumer group(s), the tuning subsample, and only the
  # two quarters that become the training and validation sets.
  filter(consumer_category %in% CONSUMER_GROUP,
         rand_no_cid %in% v_xgb_hyperparameter_rand_no_cid,
         qtr %in% c(TRAINING_START_QTR, TRAINING_END_QTR)) %>%
  select(-consumer_category)
  

# ---- Derive outcome columns and split by quarter ----

# `year` is extracted from `qtr`; `year_outcome` combines it with the numeric
# outcome as 10 * year + t_default. It is computed before `t_default` is
# converted to a factor with levels 0 and 1 (so "1" is the second level).
df_outcome <- mutate(
  df_training_raw,
  year = year(qtr),
  year_outcome = (10 * year) + t_default,
  t_default = factor(t_default, levels = c(0, 1))
)

# Rows from the start quarter are used for training ...
df_training <- filter(df_outcome, qtr == TRAINING_START_QTR)

# ... and rows from the end quarter for validation.
df_validation <- filter(df_outcome, qtr == TRAINING_END_QTR)

# ---- Preprocessing recipe (dummy encoding) ----

# Columns carried through the recipe as identifiers rather than predictors.
v_id_columns <- c("cid", "qtr", "rand_no_cid", "year", "year_outcome")

# Formula: t_default ~ <id columns> + <LASSO-selected variables>.
v_dummy_terms <- names(
  select(df_training, all_of(v_id_columns), all_of(v_lasso_vars))
)
f_xgb_dummies_ <- reformulate(termlabels = v_dummy_terms,
                              response = "t_default")

# Build the recipe step by step, then estimate it on the training rows only.
xgb_dummies_recipe <- recipe(formula = f_xgb_dummies_, data = df_training)
xgb_dummies_recipe <- update_role(xgb_dummies_recipe,
                                  all_of(v_id_columns),
                                  new_role = "id")
# step_novel() runs before step_dummy() so that factor levels not seen during
# training get their own level instead of breaking the dummy encoding.
xgb_dummies_recipe <- step_novel(xgb_dummies_recipe, all_nominal_predictors())
xgb_dummies_recipe <- step_dummy(xgb_dummies_recipe,
                                 all_nominal_predictors(),
                                 one_hot = FALSE)
xgb_dummies_recipe <- step_intercept(xgb_dummies_recipe)
xgb_dummies_recipe <- prep(xgb_dummies_recipe, strings_as_factors = FALSE)

# Apply the trained recipe to both data sets.
df_training_baked <- bake(xgb_dummies_recipe, df_training)
df_validation_baked <- bake(xgb_dummies_recipe, df_validation)

# ---- Build a single, fixed training/validation resample ----

# Stack the baked training and validation rows and number them so that each
# set can be addressed by row position.
df_all_baked <- bind_rows(df_training_baked, df_validation_baked) %>%
  mutate(
    row_id = row_number()
  )

# Row positions of the analysis (start quarter) rows ...
v_training <- df_all_baked %>%
  filter(qtr == TRAINING_START_QTR) %>%
  pull(row_id)

# ... and of the assessment (end quarter) rows.
v_validation <- df_all_baked %>%
  filter(qtr == TRAINING_END_QTR) %>%
  pull(row_id)

# Create the split directly from the two index vectors. This replaces the
# previous approach of calling the deprecated validation_split() (which draws
# a random split) and then overwriting the in_id / out_id fields of the
# resulting rsplit object by hand.
splits_plucked <- make_splits(
  list(analysis = v_training, assessment = v_validation),
  data = df_all_baked
)

# Wrap the split in an rset with one resample, as expected by tune_grid().
df_split <- manual_rset(list(splits_plucked), ids = "validation")

# ---- Model formula, specification, grid and workflow ----

# Predictors are all baked columns whose name contains "attr".
v_xgb_terms <- names(select(df_training_baked, contains("attr")))
f_xgb <- reformulate(termlabels = v_xgb_terms, response = "t_default")

# Boosted-tree specification with every main hyperparameter left for tuning.
xgb_spec <- boost_tree(
  mtry = tune(),
  trees = tune(),
  min_n = tune(),
  tree_depth = tune(),
  learn_rate = tune(),
  loss_reduction = tune(),
  sample_size = tune(),
  stop_iter = tune()
)
# Engine arguments: `event_level = "second"` and `validation = .2` are passed
# through to the xgboost engine. Per the parsnip engine documentation,
# `validation` is the proportion of data held back for early stopping.
xgb_spec <- set_engine(xgb_spec,
                       "xgboost",
                       event_level = "second",
                       validation = .2)
xgb_spec <- set_mode(xgb_spec, "classification")

# Pre-computed tuning grid read from disk.
# NOTE(review): presumably one column per tune() parameter above - confirm.
xgb_grid <- read_csv("../../data/0000_0009_TUNING_GRID_100_grid.csv")

# Workflow = formula preprocessor + model specification.
xgb_wf <- add_model(add_formula(workflow(), f_xgb), xgb_spec)

# ---- Tune over the grid in parallel ----

# Start a 3-worker cluster and register it as the foreach backend.
cl <- makeCluster(3)
registerDoParallel(cl)

tic(msg = "Tuning Grid")
xgb_res <- tryCatch(
  {
    # NOTE(review): control_grid() keeps its default `event_level`, while the
    # engine is configured with event_level = "second"; confirm that metrics
    # which depend on the event level are evaluated as intended.
    res <- tune_grid(xgb_wf,
                     grid = xgb_grid,
                     resamples = df_split,
                     control = control_grid(save_pred = TRUE, verbose = TRUE))
    # Log the elapsed time only when tuning succeeded, before the cluster is
    # shut down.
    toc(log = TRUE)
    res
  },
  finally = {
    # Always release the workers, even if tune_grid() fails, and restore the
    # sequential backend so no later foreach call targets the dead cluster.
    stopCluster(cl)
    foreach::registerDoSEQ()
  }
)

# ---- Select and persist the best hyperparameters ----

# `metric` is passed by name: in recent tune releases it comes after `...`,
# so a positional value is not matched to it. Naming it works in old and new
# versions alike. MODEL_METRIC is defined outside this file.
best_model <- select_best(xgb_res, metric = MODEL_METRIC)

path_best_model <- paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX,
                          "/final_model_xgb_", TRAINING_SUFFIX, "_",
                          RAND_NO_CID_SMALLEST_LARGEST, "_", MODEL_METRIC,
                          "/df_optimal_hyperparameters.csv")

# Make sure the output directory exists so the tuning result is not lost to a
# write error; this is a no-op when the directory is already there.
dir.create(dirname(path_best_model), recursive = TRUE, showWarnings = FALSE)

write_csv(best_model, path_best_model)

# ---- Persist the timing log ----

# Logged tic/toc entries as a character vector; as.character() keeps this a
# zero-length character vector when nothing has been logged.
v_timing_entries <- as.character(unlist(tic.log()))

# Split each entry at the colon: the text before it is the process label, the
# text after ": " is the elapsed-time message.
df_timing_log <- tibble(
  Process = str_extract(v_timing_entries, pattern = ".+(?=:)"),
  Time = str_extract(v_timing_entries, pattern = "(?<=: ).+")
)

# Reset the tictoc timer stack and log.
tic.clear()
tic.clearlog()

path_timing_log <- paste0("../../data/pipeline_outputs/",
                          SPECIAL_SUFFIX, "/", "timing_xgb_",
                          TRAINING_SUFFIX, "_", RAND_NO_CID_SMALLEST_LARGEST,
                          "_", MODEL_METRIC,
                          "/", "hyperparameter_timing_tuning.csv")

# Create the output directory if needed (no-op when it already exists).
dir.create(dirname(path_timing_log), recursive = TRUE, showWarnings = FALSE)

write_csv(df_timing_log, path_timing_log)




